import albumentations
import random
import numpy as np
from PIL import Image
import cv2
from io import BytesIO

class SimplePreprocessor(object):
    """Resize, crop and optionally augment images with `albumentations`.

    Calling an instance forwards its keyword arguments (e.g. ``image=...`` plus
    any ``additional_targets``) to the underlying pipeline and returns the
    resulting dict. When ``size`` is None or has a non-positive side, the
    pipeline is the identity and the keyword arguments are returned unchanged.
    """

    def __init__(self,
                 size=None,
                 random_crop=False,
                 horizon_flip=False,
                 change_brightness=False,
                 add_noise=False,
                 random_rotate=False,
                 smallest_max_size=None,
                 additional_targets=None):
        """
        This image preprocessor is implemented based on `albumentations`.

        Args:
            size: output size, either an int (square) or a (height, width)
                pair. None, or any side <= 0, disables all transforms.
            random_crop: take a RandomCrop of ``size`` instead of a CenterCrop.
            horizon_flip: append a HorizontalFlip (albumentations' default p).
            change_brightness: currently unsupported; raises RuntimeError.
            add_noise: currently unsupported; raises RuntimeError.
            random_rotate: append a ShiftScaleRotate restricted to rotations
                of up to 20 degrees, applied with p=0.2.
            smallest_max_size: length the shorter image side is rescaled to
                before cropping; defaults to ``min(size)``.
            additional_targets: forwarded to ``albumentations.Compose``.

        Raises:
            RuntimeError: if ``change_brightness`` or ``add_noise`` is set
                while transforms are enabled.
        """
        if isinstance(size, int):
            size = [size, size]  # height, width
        # Only derive values from `size` when it exists: the identity branch
        # below is meant to handle size=None, which used to crash right here.
        if size is not None and smallest_max_size is None:
            smallest_max_size = min(size)
        self.size = tuple(size) if size is not None else None

        if size is not None and min(size) > 0:
            transforms = [albumentations.SmallestMaxSize(max_size=smallest_max_size)]
            if random_crop:
                transforms.append(albumentations.RandomCrop(height=size[0], width=size[1]))
            else:
                transforms.append(albumentations.CenterCrop(height=size[0], width=size[1]))
            if horizon_flip:
                transforms.append(albumentations.HorizontalFlip())
            if change_brightness:
                # Intended transform: albumentations.RandomBrightnessContrast(p=0.2).
                raise RuntimeError('There is a bug in this augmentation, please do not use it before fix it!')
            if add_noise:
                # Intended transform: OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2).
                # NOTE(review): IAA* transforms may be missing from recent
                # albumentations releases; verify before re-enabling.
                raise RuntimeError('There is a bug in this augmentation, please do not use it before fix it!')
            if random_rotate:
                transforms.append(albumentations.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.0,
                                                                  rotate_limit=20, p=0.2))
            self.preprocessor = albumentations.Compose(transforms,
                                                       additional_targets=additional_targets)
        else:
            # A named function (unlike a lambda) keeps instances picklable,
            # e.g. when shipped to spawn-based data-loader workers.
            self.preprocessor = self._identity

    @staticmethod
    def _identity(**kwargs):
        """Return the keyword arguments unchanged (the no-op pipeline)."""
        return kwargs

    def __call__(self, **kwargs):
        return self.preprocessor(**kwargs)

class IPTPretrainedPreprocesser(object):
    """Create degraded copies of one image for IPT-style multi-task pre-training.

    ``__call__`` takes an HWC image and returns a dict of CHW float32 arrays:
    the clean patch under ``'image'`` and degraded versions of it under
    ``'sr'`` (down-sampled by a random factor of 2-4, then resized back),
    ``'noise'`` (additive Gaussian noise), ``'blur'`` (Gaussian blur),
    ``'cmp'`` (JPEG re-compression at quality 10-20) and ``'rain'``
    (synthetic rain).
    """

    def __init__(self, size=None, crop_size=None):
        """
        Args:
            size: if not None, every input is first resized to size x size.
            crop_size: if not None and > 0, a random crop_size x crop_size
                patch is taken after resizing; otherwise the (resized) image
                is used whole.

        Note: ``__call__`` needs a patch side length, so at least one of
        ``size`` / ``crop_size`` should be given.
        """
        self.isCrop = (crop_size is not None and crop_size > 0)
        # side length of the clean patch every degradation is derived from
        self.patch_size = crop_size if self.isCrop else size
        # optional global resize applied before any cropping
        if size is not None:
            self.resizer = albumentations.Compose([albumentations.Resize(size, size, cv2.INTER_LINEAR)])
        else:
            self.resizer = None
        if self.isCrop:
            self.crop = albumentations.Compose([albumentations.RandomCrop(height=crop_size, width=crop_size)])
        # SR *2, *3, *4: the down-sampled patch is resized back to patch size
        self.sr_resizer = albumentations.Compose(
            [albumentations.Resize(self.patch_size, self.patch_size, cv2.INTER_LINEAR)])
        # add noise (not used by __call__, which calls add_gaussnoise instead)
        self.noise = albumentations.Compose([albumentations.GaussNoise(var_limit=(100, 5625), p=1.0)])
        # add blur
        self.blur = albumentations.Compose([albumentations.GaussianBlur(blur_limit=(3, 15), p=1.0)])
        # add rain
        self.rain = albumentations.Compose([albumentations.RandomRain(brightness_coefficient=0.9, blur_value=5,
                                                                      drop_width=1, p=1.0)])

    @staticmethod
    def _to_chw(img):
        """Convert an HWC image into a CHW float32 array."""
        return np.transpose(img.astype(np.float32), (2, 0, 1))

    @staticmethod
    def _resize(img, height, width):
        """Bilinearly resize ``img`` to (height, width) with OpenCV.

        This calls cv2.resize directly rather than going through
        ``albumentations.augmentations.functional.resize``, whose location
        has moved between albumentations releases. cv2 drops a trailing
        size-1 channel axis, so that axis is restored here.
        """
        out = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        if img.ndim == 3 and out.ndim == 2:
            out = out[..., None]
        return out

    def image_compress(self, img_np, formats='jpeg', quality=75):
        """Round-trip an image through lossy compression and return it as uint8.

        Args:
            img_np: HW or HWC image. Floating-point input whose maximum is
                below 2 is treated as [-1, 1] data and mapped to [0, 255].
            formats: PIL format name used for encoding.
            quality: encoder quality forwarded to PIL.

        Returns:
            The decoded image as a uint8 array.
        """
        img_np = np.asarray(img_np)
        # Only rescale float input: a dark uint8 image (all pixels 0 or 1)
        # would otherwise be mistaken for [-1, 1] data and brightened.
        if np.issubdtype(img_np.dtype, np.floating) and np.max(img_np) < 2:
            img_np = ((img_np + 1.0) * 127.5).astype(np.uint8)
        img = Image.fromarray(np.uint8(img_np))
        buffer = BytesIO()
        img.save(buffer, format=formats, quality=quality)
        out_img = Image.open(BytesIO(buffer.getvalue()))
        return np.array(out_img).astype(np.uint8)

    def add_gaussnoise(self, img_np, var, mean=0):
        """Add Gaussian noise to an image and clip the result to [0, 255].

        Args:
            img_np: HW or HWC image; for HWC input the same noise field is
                added to every channel.
            var: noise variance in pixel units (std = sqrt(var)), the same
                convention as ``albumentations.GaussNoise(var_limit=...)``.
            mean: noise mean.

        Returns:
            A float array with the same shape as ``img_np``.
        """
        shape = img_np.shape
        # `var` is a variance. Using it directly as the std (as before) turned
        # var in [100, 5625] into sigma up to 5625 and saturated almost every
        # pixel to 0 or 255.
        noise = np.random.randn(shape[0], shape[1]) * np.sqrt(var) + mean
        if len(shape) == 3:
            noise = noise[..., None]
        return np.clip(img_np.astype(np.float32) + noise, 0, 255)

    def __call__(self, input):
        """Return the clean patch and its degraded versions as CHW float32 arrays.

        Args:
            input: HWC image array.

        Returns:
            dict with keys 'sr', 'noise', 'blur', 'cmp', 'rain' and 'image'.
        """
        # clean reference patch: optional global resize, then optional random crop
        if self.resizer is not None:
            resize_image = self.resizer(image=input)["image"]
        else:
            resize_image = input
        if self.isCrop:
            if resize_image.shape[0] < self.patch_size or resize_image.shape[1] < self.patch_size:
                # too small to crop: stretch to exactly patch_size x patch_size
                resize_image = self._resize(resize_image, self.patch_size, self.patch_size)
            crop_image = self.crop(image=resize_image)["image"]
        else:
            crop_image = resize_image
        out_crop = self._to_chw(crop_image)

        # SR: down-sample by a random factor in {2, 3, 4}, then resize back up
        random_low_rate = random.randint(2, 4)
        low_side = int(self.patch_size / random_low_rate)
        low_image = self._resize(crop_image, low_side, low_side)
        low_image = self.sr_resizer(image=low_image)["image"]
        out_sr = self._to_chw(low_image)

        # noise: variance drawn from [100, 5625], i.e. sigma in [10, 75]
        var_value = random.randint(100, 5625)
        out_noise = self._to_chw(self.add_gaussnoise(crop_image, var_value))

        # blur
        out_blur = self._to_chw(self.blur(image=crop_image.astype(np.uint8))["image"])

        # JPEG compression at a low random quality
        random_quality = random.randint(10, 20)
        out_cmp = self._to_chw(self.image_compress(crop_image.astype(np.uint8), quality=random_quality))

        # rain
        out_rain = self._to_chw(self.rain(image=crop_image.astype(np.uint8))["image"])

        return {'sr': out_sr, 'noise': out_noise, 'blur': out_blur, 'cmp': out_cmp,
                'rain': out_rain, 'image': out_crop}

class IPTSRPretrainedPreprocesser(object):
    """Produce paired low-/high-resolution patches for super-resolution training.

    ``__call__`` takes an HWC image and returns ``{'sr': lr, 'image': hr}``,
    both CHW float32 arrays. ``hr`` is a random ``patch_size`` crop, or the
    whole image stretched to that size when either side is too small. ``lr``
    is ``hr`` down-sampled by ``scale`` and, when ``input_large`` is set,
    resized back to ``patch_size``. Both get the same random flip/transpose.
    """

    def __init__(self, patch_size=None, scale=2, input_large=False):
        """
        Args:
            patch_size: side length of the high-resolution patch. It must be
                given: the default None fails in the arithmetic below.
            scale: down-sampling factor for the low-resolution patch.
            input_large: if True, resize the LR patch back to patch_size.
        """
        self.patch_size = patch_size
        self.scale = scale
        # LR side honours `scale`; it used to be hard-coded to patch_size // 2
        # (which the default scale=2 reproduces).
        self.sr_size = patch_size // scale
        self.input_large = input_large
        self.croper = albumentations.Compose([albumentations.RandomCrop(height=patch_size, width=patch_size)])
        self.resizer = albumentations.Compose([albumentations.Resize(patch_size, patch_size, cv2.INTER_LINEAR)])
        self.down_resizer = albumentations.Compose([albumentations.Resize(self.sr_size, self.sr_size,
                                                                          cv2.INTER_LINEAR)])

    def augment(self, *args, hflip=True, rot=True):
        """Apply one random flip/transpose combination to every array in ``args``.

        Every array gets the same transform, which keeps LR/HR pairs aligned.
        Arrays are indexed as (H, W, C). Returns a list in the input order.
        """
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        rot90 = rot and random.random() < 0.5

        def _augment(img):
            if hflip: img = img[:, ::-1, :]
            if vflip: img = img[::-1, :, :]
            # swapping H and W; combined with the flips this covers all 8
            # flip/rotation variants
            if rot90: img = img.transpose(1, 0, 2)
            return img
        return [_augment(a) for a in args]

    def __call__(self, input):
        """Return ``{'sr': lr, 'image': hr}`` as CHW float32 arrays for an HWC image."""
        if input.shape[0] < self.patch_size or input.shape[1] < self.patch_size:
            # too small to crop: stretch the whole image to patch_size x patch_size
            crop_image = self.resizer(image=input)["image"]
        else:
            crop_image = self.croper(image=input)["image"]
        lr_image = self.down_resizer(image=crop_image)["image"]
        if self.input_large:
            lr_image = self.resizer(image=lr_image)["image"]
        lr_image, hr_image = self.augment(lr_image, crop_image)

        lr_image = np.transpose(lr_image.astype(np.float32), (2, 0, 1))
        hr_image = np.transpose(hr_image.astype(np.float32), (2, 0, 1))

        return {'sr': lr_image, 'image': hr_image}
        




class DalleVAEPreprocessor(object):
    """Square-crop and rescale images for the DALL-E VAE.

    * ``phase == 'train'``: cut a random square whose side is the image's
      short side, resize it to a random side between roughly 9/8 and 12/8 of
      ``size`` (capped by the square's side, never below ``size``), then take
      a random ``size`` x ``size`` crop.
    * any other phase: resize so the short side equals ``size``, then take
      the central ``size`` x ``size`` crop.

    Calling an instance returns the albumentations output dict; the processed
    uint8 array is stored under ``'image'``.
    """

    def __init__(self,
                 size=256,
                 phase='train',
                 additional_targets=None):
        self.size = size
        self.phase = phase
        random_crop = albumentations.RandomCrop(height=size, width=size)
        center_crop = albumentations.CenterCrop(height=size, width=size)
        # NOTE(review): additional targets are registered on both pipelines,
        # but __call__ only ever passes `image`.
        self.train_preprocessor = albumentations.Compose([random_crop], additional_targets=additional_targets)
        self.val_preprocessor = albumentations.Compose([center_crop], additional_targets=additional_targets)

    def _random_square_rescale(self, image):
        """Random short-side square crop, then resize to a random side >= size."""
        width, height = image.size
        side = min(height, width)
        left = random.randint(0, width - side) if width > side else 0
        top = random.randint(0, height - side) if height > side else 0
        image = image.crop((left, top, left + side, top + side))

        lo = min(side, round(9 / 8 * self.size))
        hi = min(side, round(12 / 8 * self.size))
        target = max(int(random.uniform(lo, hi + 1)), self.size)
        return image.resize((target, target))

    def _short_side_resize(self, image):
        """Resize keeping the aspect ratio so the shorter side equals size."""
        width, height = image.size
        if width < height:
            new_w = self.size
            new_h = int(height * new_w / width)
        else:
            new_h = self.size
            new_w = int(width * new_h / height)
        return image.resize((new_w, new_h))

    def __call__(self, image, **kargs):
        """
        image: PIL.Image, or a numpy array (cast to uint8 and wrapped as PIL).
        Extra keyword arguments are accepted but ignored.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))

        if self.phase == 'train':
            resized = self._random_square_rescale(image)
            pipeline = self.train_preprocessor
        else:
            resized = self._short_side_resize(image)
            pipeline = self.val_preprocessor
        return pipeline(image=np.array(resized).astype(np.uint8))


class DalleTransformerPreprocessor(object):
    """Square-crop and rescale images for the DALL-E transformer stage.

    * ``phase == 'train'``: cut a short-side square whose offset is drawn from
      about the middle quarter (3/8 to 5/8) of the spare length, resize it to
      a random side between ``size`` and about 9/8 * ``size`` (capped by the
      square's side), then take a random ``size`` x ``size`` crop.
    * any other phase: resize so the short side equals ``size``, then take
      the central ``size`` x ``size`` crop.

    Calling an instance returns the albumentations output dict; the processed
    uint8 array is stored under ``'image'``.
    """

    def __init__(self,
                 size=256,
                 phase='train',
                 additional_targets=None):
        self.size = size
        self.phase = phase
        # ddc: following dalle to use randomcrop
        random_crop = albumentations.RandomCrop(height=size, width=size)
        center_crop = albumentations.CenterCrop(height=size, width=size)
        # NOTE(review): additional targets are registered on both pipelines,
        # but __call__ only ever passes `image`.
        self.train_preprocessor = albumentations.Compose([random_crop], additional_targets=additional_targets)
        self.val_preprocessor = albumentations.Compose([center_crop], additional_targets=additional_targets)

    def _center_jittered_square(self, image):
        """Near-center short-side square crop, then a mild random upscale."""
        width, height = image.size
        side = min(height, width)

        def offset(slack):
            # uniform over roughly [3/8, 5/8] of the spare length
            low = 3 * slack // 8
            return int(random.uniform(low, max(low + 1, 5 * slack // 8)))

        top = offset(height - side)
        left = offset(width - side)
        image = image.crop((left, top, left + side, top + side))

        upper = max(min(side, round(9 / 8 * self.size)), self.size)
        target = int(random.uniform(self.size, upper + 1))
        return image.resize((target, target))

    def _short_side_resize(self, image):
        """Resize keeping the aspect ratio so the shorter side equals size."""
        width, height = image.size
        if width < height:
            new_w = self.size
            new_h = int(height * new_w / width)
        else:
            new_h = self.size
            new_w = int(width * new_h / height)
        return image.resize((new_w, new_h))

    def __call__(self, image, **kargs):
        """
        image: PIL.Image, or a numpy array (cast to uint8 and wrapped as PIL).
        Extra keyword arguments are accepted but ignored.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))

        if self.phase == 'train':
            resized = self._center_jittered_square(image)
            pipeline = self.train_preprocessor
        else:
            resized = self._short_side_resize(image)
            pipeline = self.val_preprocessor
        return pipeline(image=np.array(resized).astype(np.uint8))


if __name__ == '__main__':
    # Manual smoke test for the DALL-E preprocessors.
    # Usage: python <this file> [IMAGE_PATH]
    # Without a path, a random 300x400 RGB image is used, so the check does
    # not depend on a machine-specific dataset location.
    import sys

    if len(sys.argv) > 1:
        im = Image.open(sys.argv[1]).convert('RGB')
    else:
        im = Image.fromarray(np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8))

    for preprocessor_cls in (DalleVAEPreprocessor, DalleTransformerPreprocessor):
        for phase in ('train', 'val'):
            out = preprocessor_cls(size=256, phase=phase)(image=im)
            print(preprocessor_cls.__name__, phase, out['image'].shape)
